#pytorch MNIST GPU-CPU
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
import time
from torch.autograd import Variable
#torch.manual_seed(1)

# Hyperparameters
# PREFER_GPU is the user-facing switch (set it to False to force a CPU run).
# The GPU is only used when CUDA is actually available, so the script falls
# back to the CPU instead of failing on the first `.cuda()` call.
PREFER_GPU = True
use_gpu = PREFER_GPU and torch.cuda.is_available()
EPOCH = 1  # number of passes over the training set
BATCH_SIZE = 50  # mini-batch size
BACTH_SIZE = BATCH_SIZE  # original (misspelled) name, kept for existing references
LR = 0.001  # learning rate
DOWNLOAD_MNIST = False  # intended as the `download=` argument of the MNIST dataset

# Training split. ToTensor converts each PIL image / numpy.ndarray into a
# torch.FloatTensor laid out as (C, H, W) with values scaled to [0, 1].
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    # Download is hard-coded to True here; DOWNLOAD_MNIST is not consulted.
    download=True,
)
# Test split: no transform is attached, the raw tensors are used directly below.
test_data = torchvision.datasets.MNIST(root='./mnist', train=False)

train_loader = Data.DataLoader(dataset=train_data, batch_size=BACTH_SIZE, shuffle=True)

# Evaluate on the first 2000 test samples only.
# Shape goes from (2000, 28, 28) to (2000, 1, 28, 28); values are scaled to [0, 1].
# Slicing happens before the float conversion so only the used samples are converted.
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels` aliases.
test_x = test_data.data[:2000].unsqueeze(1).float() / 255
test_y = test_data.targets[:2000]

class CNN(nn.Module):
    """Three-stage convolutional classifier with a 10-way linear head.

    Every stage is Conv2d(5x5, stride 1, padding 2) -> ReLU -> MaxPool2d(2).
    The padding keeps the spatial size unchanged through the convolution
    (padding = (kernel_size - 1) / 2 when stride is 1), so only the pooling
    layers shrink the feature maps.

    For a (N, 1, 28, 28) input the stages produce (N, 16, 14, 14),
    (N, 32, 7, 7) and (N, 64, 4, 4); the last pooling layer pads by 1 so the
    7x7 map becomes 4x4. The flattened 64 * 4 * 4 features feed the linear
    output layer.
    """

    @staticmethod
    def _stage(in_channels, out_channels, pool_padding=0):
        """Build one conv -> ReLU -> max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels,   # number of input channels
                out_channels,  # number of filters
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, padding=pool_padding),
        )

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(1, 16)
        self.conv2 = self._stage(16, 32)
        self.conv3 = self._stage(32, 64, pool_padding=1)
        self.out = nn.Linear(64 * 4 * 4, 10)

    def forward(self, x):
        """Return the 10 class scores for each sample in the batch ``x``."""
        features = self.conv3(self.conv2(self.conv1(x)))
        # Flatten the feature maps to (batch_size, 64 * 4 * 4).
        flat = features.view(features.size(0), -1)
        return self.out(flat)
# Build the network; its parameters are moved to the GPU (when requested)
# before the optimizer is created over them.
cnn = CNN()
cnn = cnn.cuda() if use_gpu else cnn
print(cnn)

# Cross-entropy loss on the raw class scores, optimized with Adam.
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)

# Train for EPOCH passes over train_loader. Every 50 steps the current
# mini-batch loss and the accuracy on the fixed evaluation subset
# (test_x / test_y) are printed; the wall-clock time of each epoch is
# printed at its end.

# The evaluation inputs never change, so move them to the GPU once instead
# of on every training step.
if use_gpu:
    test_x = test_x.cuda()
test_y_np = test_y.numpy()
n_test = float(test_y.size(0))

for epoch in range(EPOCH):
    start = time.time()
    for step, (b_x, b_y) in enumerate(train_loader):
        # Plain tensors are enough: no gradient with respect to the input
        # images is needed, only with respect to the model parameters.
        if use_gpu:
            b_x = b_x.cuda()
            b_y = b_y.cuda()

        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # Evaluation needs no autograd graph; skipping it avoids keeping
            # the activations of all evaluation samples in memory.
            with torch.no_grad():
                test_output = cnn(test_x)
            # .cpu() is a no-op for CPU tensors, so one code path serves
            # both the GPU and the CPU configuration.
            pred_y = test_output.argmax(dim=1).cpu().numpy()
            accuracy = float((pred_y == test_y_np).astype(int).sum()) / n_test
            print("Epoch:%s" % str(epoch), "|train loss:%.4f" % loss.item(),
                  "|test accuracy:%.2f" % accuracy)
    duration = time.time() - start
    print("Epoch:%s" % epoch, "Training duration:%.4f" % duration)
#print(cnn)
